import matplotlib.pyplot as plt
import seaborn as sns; sns.set_style('darkgrid')
import numpy as np
import torch
import pandas as pd
from torchvision import datasets, transforms
import torch.nn.functional as F
from sudoku import MNISTSudokuSolver
import tqdm

# Test performance of MNIST classifier
# Evaluate on CUDA device 0 when a GPU is present; otherwise fall back to the
# CPU so `.to(device)` below does not fail on GPU-less machines.
device = 0 if torch.cuda.is_available() else 'cpu'
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        '../data',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=500,
    shuffle=False,
    num_workers=4,
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
    pin_memory=torch.cuda.is_available(),
)
# NOTE(review): the arguments (3, 300, 600) appear to mirror the
# "boardSz3-aux300-m600" part of the checkpoint directory names used below;
# confirm against MNISTSudokuSolver's signature.
model = MNISTSudokuSolver(3, 300, 600)
def test(model, loader=None, dev=None):
    """Return the top-1 accuracy (in percent) of ``model`` on a data loader.

    Each target label ``t`` is remapped to ``(t - 1) % 10`` before comparison,
    so digit ``d`` is expected at output index ``d - 1`` and digit 0 at index 9.

    Args:
        model: module returning per-class scores with classes along dim 1;
            it is switched to eval mode.
        loader: iterable of ``(data, target)`` batches that exposes
            ``.dataset``; defaults to the module-level ``test_loader``.
        dev: device the batches are moved to; defaults to the module-level
            ``device``.

    Returns:
        Percentage of samples in ``loader.dataset`` whose argmax prediction
        equals the remapped label.
    """
    # Resolve defaults lazily so existing `test(model)` callers keep working.
    if loader is None:
        loader = test_loader
    if dev is None:
        dev = device
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(dev), target.to(dev)
            # Shift labels so digit d maps to class d-1 (0 wraps around to 9).
            target = (target - 1) % 10
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)  # index of the highest score
            correct += pred.eq(target.view_as(pred)).sum().item()
    return 100. * correct / len(loader.dataset)


# This is an evaluation helper, not a pytest test case, despite its name.
test.__test__ = False
# For every run (seed) and both output-masking settings, evaluate the digit
# classifier of each saved checkpoint it1..it100 on the MNIST test set and cache
# the accuracy curves to mnist_acc_seed<seed>.dict.
# Seeds 1..10 match the run directories read by the Table 1 loop; the previous
# range(10) computed seed 0 instead of seed 10.
for seed in range(1, 11):
    results = {}
    for o in [True, False]:
        mnist_accs = []
        for it in tqdm.trange(1, 101):
            file = './cloud_logs/seed%d-permTrue-mnistTrue-boardSz3-aux300-m600-lr0.002-bsz40-outputMask%s/it%d.pth' % (seed, o, it)
            model.load_state_dict(torch.load(file, map_location='cpu'))
            # Only the digit classifier sub-module is evaluated.
            test_acc = test(model.digit_convnet.to(device))
            mnist_accs.append(test_acc)
        results['output_masked' if o else 'original'] = mnist_accs
    torch.save(results, 'mnist_acc_seed%d.dict' % seed)
    print("seed %d done" % seed)

# Generate Table 1 Results
# For each (visual input, output masking, split) setting, read the last logged
# error of every seed's run, convert it to accuracy (%), and print the mean and
# the standard error of the mean (sample std / sqrt(number of seeds)).
table_seeds = list(range(1, 11))
for m in [True, False]:
    for o in [True, False]:
        for train in ['train', 'test']:
            accs = np.zeros(len(table_seeds))
            for i, seed in enumerate(table_seeds):
                file = './cloud_logs/seed%d-permTrue-mnist%s-boardSz3-aux300-m600-lr0.002-bsz40-outputMask%s/%s.csv' % (seed, m, o, train)
                result = pd.read_csv(file)
                # Final row of the log -> accuracy in percent.
                accs[i] = (1.0 - result['err'].values[-1]) * 100.0
            # This is the standard error, not the standard deviation, so label it "sem".
            sem = accs.std(ddof=1) / np.sqrt(len(accs))
            print("visual:%s, output_masked:%s, train:%s, mean:%.1f, sem:%.1f" % (m, o, train, accs.mean(), sem))
            print(accs)  # seeds 2 and 7 succeeded

# Generate Plot (Figure 3)
# 2x2 grid: left column is seed 2 (titled successful), right column is seed 1
# (titled unsuccessful). The top row shows visual-Sudoku training accuracy per
# epoch, and the bottom row shows the cached MNIST test accuracy of the digit
# classifier ('original', i.e. no output masking).
plt.figure(figsize=(9, 5))
color_palette = sns.color_palette()
for seed in [2, 1]:
    successful = seed == 2

    # Top row: Sudoku training accuracy.
    if successful:
        plt.subplot(221)
        plt.axvline(9, color=color_palette[3], linestyle='--')
        plt.title('Visual Sudoku (Successful)')
    else:
        plt.subplot(222)
        plt.title('Visual Sudoku (Unsuccessful)')
    plt.ylabel('Training Acc (%)')
    plt.xlabel('Epochs')
    plt.ylim(0, 100)
    file = './cloud_logs/seed%d-permTrue-mnistTrue-boardSz3-aux300-m600-lr0.002-bsz40-outputMaskFalse/train.csv' % seed
    result = pd.read_csv(file)
    epochs = result['epoch'].values
    sudoku_accs = (1.0 - result['err'].values) * 100
    # x/y must be keywords: seaborn >= 0.12 makes them keyword-only.
    sns.lineplot(x=epochs, y=sudoku_accs, color=color_palette[0])

    # Bottom row: MNIST test accuracy of the digit classifier.
    if successful:
        plt.subplot(223)
        plt.axvline(9, color=color_palette[3], linestyle='--')
    else:
        plt.subplot(224)
    plt.title('MNIST')
    plt.ylabel('Test Acc (%)')
    plt.xlabel('Epochs')
    plt.ylim(0, 100)
    file_2 = "mnist_acc_seed%d.dict" % seed
    mnist_accs = torch.load(file_2)['original']
    sns.lineplot(x=epochs, y=mnist_accs, color=color_palette[2])

    print(sudoku_accs, mnist_accs)
plt.tight_layout()
# plt.show()
plt.savefig('symbol_grounding.pdf')